use futures::future::join_all;
use screenpipe_audio::core::device::default_input_device;
use screenpipe_audio::core::engine::AudioTranscriptionEngine;
use screenpipe_audio::speaker::embedding::EmbeddingExtractor;
use screenpipe_audio::speaker::embedding_manager::EmbeddingManager;
use screenpipe_audio::speaker::prepare_segments;
use screenpipe_audio::transcription::stt::SAMPLE_RATE;
use screenpipe_audio::transcription::whisper::model::{
    create_whisper_context_parameters, download_whisper_model,
};
use screenpipe_audio::vad::{silero::SileroVad, VadEngine};
use screenpipe_audio::{resample, stt, AudioInput};
use screenpipe_core::Language;
use std::path::PathBuf;
use std::sync::Arc;
use strsim::levenshtein;
use tokio::sync::Mutex;
use tracing::debug;
use whisper_rs::WhisperContext;

#[tokio::test]
#[ignore]
async fn test_transcription_accuracy() {
    // Initialize tracing
    // tracing_subscriber::fmt()
    //     .with_max_level(tracing::Level::DEBUG)
    //     .init();

    debug!("starting transcription accuracy test");

    // Setup
    let test_cases = vec![
        (
            "test_data/accuracy1.wav",
            r#"yo louis, here's the tldr of that mind-blowing meeting. bob's cat walked across his keyboard 3 times. productivity increased by 200%. sarah's virtual background glitched, revealing she was actually on a beach. no one noticed. you successfully pretended to be engaged while scrolling twitter. achievement unlocked! 7 people said "you're on mute" in perfect synchronization. new world record. meeting could've been an email. shocking. key takeaway: we're all living in a simulation, and the devs are laughing. peace out, llama3.2:3b-instruct-q4_k_m"#,
        ),
        (
            "test_data/accuracy2.wav",
            r#"bro - got some good stuff from screenpipe here's the lowdown on your day, you productivity ninja: absolutely demolished that 2-hour coding sesh on the new feature. the keyboard is still smoking, bro! crushed 3 client calls like a boss. they're probably writing love letters to you as we speak, make sure to close john tomorrow 8.00 am according to our notes, let the cash flow in! spent 45 mins on slack. 90% memes, 10% actual work. perfectly balanced, as all things should bewatched a rust tutorial. way to flex those brain muscles, you nerd! overall, you're killing it! 80% of your time on high-value tasks. the other 20%? probably spent admiring your own reflection, you handsome devil. ps: seriously, quit tiktok. your fbi agent is getting bored watching you scroll endlessly. what's the plan for tomorrow? more coding? more memes? world domination? generated by your screenpipe ai assistant (who's definitely not planning to take over the world... yet)"#,
        ),
        (
            "test_data/accuracy3.wav",
            r#"again, screenpipe allows you to get meeting summaries, locally, without leaking data to openai, with any apps, like whatsapp, meet, zoom, etc. and it's open source at github.com/mediar-ai/screenpipe"#,
        ),
        (
            "test_data/accuracy4.wav",
            r#"eventually but, i mean, i feel like but, i mean, first, i mean, you think your your vision smart will be interesting because, yeah, you install once. you pay us, you install once. that that yours. so, basically, all the time microsoft explained, you know, ms office, long time ago, you just buy the the the software that you can using there forever unless you wanna you wanna update upgrade is the better version. right? so it's a little bit, you know"#,
        ),
        (
            "test_data/accuracy5.wav",
            r#"thank you. yeah. so i cannot they they took it, refresh because of my one set top top time. and, also, second thing is, your byte was stolen. by the time?"#,
        ),
        // Add more test cases as needed
    ];

    let context_params =
        create_whisper_context_parameters(Arc::new(AudioTranscriptionEngine::WhisperTinyQuantized))
            .unwrap();

    let quantized_path =
        download_whisper_model(Arc::new(AudioTranscriptionEngine::WhisperTinyQuantized)).unwrap();
    let whisper_context = Arc::new(
        WhisperContext::new_with_params(&quantized_path.to_string_lossy(), context_params)
            .expect("failed to load model"),
    );

    let vad_engine: Arc<Mutex<Box<dyn VadEngine + Send>>> =
        Arc::new(Mutex::new(Box::new(SileroVad::new().await.unwrap())));

    let mut tasks = Vec::new();

    let project_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    let segmentation_model_path = project_dir
        .join("models")
        .join("pyannote")
        .join("segmentation-3.0.onnx");

    let embedding_model_path = project_dir
        .join("models")
        .join("pyannote")
        .join("wespeaker_en_voxceleb_CAM++.onnx");

    let embedding_extractor = Arc::new(std::sync::Mutex::new(
        EmbeddingExtractor::new(embedding_model_path.to_str().unwrap()).unwrap(),
    ));

    for (audio_file, expected_transcription) in test_cases {
        let whisper_context = whisper_context.clone();
        let vad_engine = Arc::clone(&vad_engine);

        let embedding_extractor = Arc::clone(&embedding_extractor);
        let embedding_manager = EmbeddingManager::new(usize::MAX);
        let segmentation_model_path = segmentation_model_path.clone();

        let task = tokio::spawn(async move {
            let audio_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(audio_file);
            let audio_data =
                screenpipe_audio::pcm_decode(&audio_path).expect("Failed to decode audio file");

            let audio_input = AudioInput {
                data: Arc::new(audio_data.0),
                sample_rate: 44100, // hardcoded based on test data sample rate
                channels: 1,
                device: Arc::new(default_input_device().unwrap()),
            };

            let audio_data = if audio_input.sample_rate != SAMPLE_RATE {
                match resample(
                    audio_input.data.as_ref(),
                    audio_input.sample_rate,
                    SAMPLE_RATE,
                ) {
                    Ok(data) => data,
                    Err(e) => {
                        panic!("Error resampling audio: {:?}", e);
                    }
                }
            } else {
                audio_input.data.as_ref().to_vec()
            };

            let (mut segments, _) = prepare_segments(
                &audio_data,
                vad_engine.clone(),
                &segmentation_model_path,
                embedding_manager,
                embedding_extractor,
                &audio_input.device.name,
            )
            .await
            .unwrap();

            let mut transcription = String::new();
            while let Some(segment) = segments.recv().await {
                let transcript = stt(
                    &segment.samples,
                    audio_input.sample_rate,
                    &audio_input.device.to_string(),
                    Arc::new(AudioTranscriptionEngine::WhisperLargeV3Turbo),
                    None,
                    vec![Language::English],
                    whisper_context.clone(),
                )
                .await
                .unwrap();

                transcription.push_str(&transcript);
            }

            let distance = levenshtein(expected_transcription, &transcription.to_lowercase());
            let accuracy = 1.0 - (distance as f64 / expected_transcription.len() as f64);

            (audio_file, expected_transcription, transcription, accuracy)
        });

        tasks.push(task);
    }

    let results = join_all(tasks).await;

    let mut total_accuracy = 0.0;
    let mut total_tests = 0;

    for result in results {
        let (audio_file, expected_transcription, transcription, accuracy) = result.unwrap();

        println!("file: {}", audio_file);
        println!("expected: {}", expected_transcription);
        println!("actual: {}", transcription);
        println!("accuracy: {:.2}%", accuracy * 100.0);
        // println!();

        total_accuracy += accuracy;
        total_tests += 1;
    }

    let average_accuracy = total_accuracy / total_tests as f64;
    println!("average accuracy: {:.2}%", average_accuracy * 100.0);

    assert!(average_accuracy > 0.55, "average accuracy is below 55%");
}
